import numpy as np
import matplotlib.pyplot as plt
import pdb

import argparse
# Command-line options of the simulation. Each entry is
# (flag, default, type, help text); every option has a default value.
parser = argparse.ArgumentParser()
for _flag, _default, _kind, _help in (
    ("--d", 100, int, "input dimension"),
    ("--T", 10000, int, "number of tasks"),
    ("--n", 10, int, "per-task data size"),
    ("--n1", 10, int, "training data set size"),
    ("--sigma", 0.0, float, "additive noise"),
    ("--run", 5, int, "number of independent simulations"),
):
    parser.add_argument(_flag, default=_default, type=_kind, help=_help)
# Parsed at import time from the process command line.
args = parser.parse_args()

def sample_generate(d, n, sigma, w_star):
    """Draw one synthetic linear-regression task.

    Args:
        d: Number of input features.
        n: Number of samples in the task.
        sigma: Scale applied to the standard-normal label noise.
        w_star: Centre of the task-weight distribution; added to a length-``d``
            perturbation, so it must broadcast against shape ``(d,)``.

    Returns:
        A tuple ``(X, y)`` where ``X`` is an ``(n, d)`` standard-normal design
        matrix and ``y = X @ w + eps`` with task weights
        ``w = w_star + randn(d) / sqrt(d)`` and noise ``eps = sigma * randn(n)``.

    Random numbers come from NumPy's global ``np.random`` state, drawn in the
    order: design matrix, weight perturbation, noise.
    """
    features = np.random.randn(n, d)
    # Task-specific weights: w_star plus a Gaussian perturbation scaled by 1/sqrt(d).
    task_weights = w_star + np.random.randn(d) / np.sqrt(d)
    noise = np.random.randn(n) * sigma
    targets = features.dot(task_weights) + noise
    return features, targets

def _task_normal_equations(gram, moment, X_eval, y_eval, lam):
    """Return one task's contribution ``(A_t, b_t)`` to a pooled least-squares system.

    ``gram`` and ``moment`` are ``X_fit.T @ X_fit`` and ``X_fit.T @ y_fit`` of
    the samples used for the ridge fit.  With ``R = gram + lam * I`` this forms

        M = X_eval @ (I - R^{-1} gram)
        r = y_eval - X_eval @ (R^{-1} moment)

    and returns ``(M.T @ M, M.T @ r)``.
    """
    d = gram.shape[0]
    identity = np.eye(d)
    # One factorisation of R serves both right-hand sides: the first d columns
    # of the solution are R^{-1} gram, the last column is R^{-1} moment.
    solved = np.linalg.solve(gram + lam * identity, np.column_stack((gram, moment)))
    M = X_eval @ (identity - solved[:, :d])
    residual = y_eval - X_eval @ solved[:, d]
    return M.T @ M, M.T @ residual


def error(d, n, n1, T, run, sigma, opt_lam, w_star, *, split_lam=10000.0):
    """Estimate the squared error of the no-split and split estimators of ``w_star``.

    For each of ``run`` independent simulations, ``T`` tasks are drawn with
    ``sample_generate`` and two pooled linear systems ``A w = b`` are
    accumulated over the tasks:

    * no-split: all ``n`` samples of a task are used both for the ridge fit
      and for the residual, once for every regularisation value on the grid
      ``linspace(0.01, 3.01, 30) * n``; the smallest error over the grid is kept.
    * split: the first ``n1`` samples are used for the ridge fit (with
      regularisation ``split_lam``) and the remaining ``n - n1`` samples for
      the residual.  When ``n1 <= 0`` the plain least-squares system
      ``(X.T X, X.T y)`` on all samples is accumulated instead.

    NOTE(review): algebraically these look like the normal equations for a
    shared initialisation ``w0`` under the per-task estimate
    ``argmin_w ||Xw - y||^2 + lam * ||w - w0||^2`` - confirm against the
    accompanying derivation.

    Args:
        d: Input dimension.
        n: Number of samples per task.
        n1: Number of samples per task used for the ridge fit in the split
            estimator.
        T: Number of tasks per simulation.
        run: Number of independent simulations to average over.
        sigma: Noise scale forwarded to ``sample_generate``.
        opt_lam: Unused; accepted only so existing callers keep working.
        w_star: Ground-truth weight vector of length ``d``.
        split_lam: Ridge regularisation used by the split estimator.

    Returns:
        ``(T * mean_no_split_error, T * mean_split_error)`` where each error is
        ``sum((w_hat - w_star) ** 2)`` averaged over the ``run`` simulations.

    Raises:
        ValueError: If ``0 < n <= n1``, i.e. the split leaves no validation
            samples (the split system would be identically zero).
        numpy.linalg.LinAlgError: If an accumulated system is singular.
    """
    if n1 > 0 and n1 >= n:
        # Fail fast: otherwise the whole simulation runs before the all-zero
        # split system triggers an opaque "Singular matrix" error.
        raise ValueError(
            f"n1 must be smaller than n so the validation split is non-empty "
            f"(got n1={n1}, n={n})"
        )

    # The regularisation grid does not depend on the run, so build it once.
    lam = np.linspace(0.01, 3.01, 30) * n

    error1_run = 0
    error2_run = 0
    for _run_index in range(run):
        A1 = np.zeros((len(lam), d, d))
        b1 = np.zeros((len(lam), d))
        A2 = np.zeros((d, d))
        b2 = np.zeros(d)
        for _task_index in range(T):
            Xt, yt = sample_generate(d, n, sigma, w_star)
            Xt_square = Xt.T @ Xt
            Xt_moment = Xt.T @ yt

            # Split estimator: fit on the first n1 samples, evaluate on the rest.
            if n1 > 0:
                Xt_train, yt_train = Xt[:n1], yt[:n1]
                A2t, b2t = _task_normal_equations(
                    Xt_train.T @ Xt_train,
                    Xt_train.T @ yt_train,
                    Xt[n1:],
                    yt[n1:],
                    split_lam,
                )
            else:
                A2t, b2t = Xt_square, Xt_moment
            A2 += A2t
            b2 += b2t

            # No-split estimator: fit and evaluate on all samples, per lambda.
            for i, lam_value in enumerate(lam):
                A1t, b1t = _task_normal_equations(Xt_square, Xt_moment, Xt, yt, lam_value)
                A1[i] += A1t
                b1[i] += b1t

        # Tune lambda by keeping the smallest error over the grid.
        error_lam = np.array([
            np.sum((np.linalg.solve(A_lam, b_lam) - w_star) ** 2)
            for A_lam, b_lam in zip(A1, b1)
        ])
        error1 = error_lam.min()

        w2_star = np.linalg.solve(A2, b2)
        error2 = np.sum((w2_star - w_star) ** 2)

        error1_run += error1
        error2_run += error2

    return T * error1_run / run, T * error2_run / run

# Sweep grids: 50 task counts ending at args.T and 30 input dimensions ending
# at args.d (dtype=int truncates the evenly spaced values to integers).
task = np.linspace(args.T / 50, args.T, 50, dtype=int)
dimension = np.linspace(args.d / 30, args.d, 30, dtype=int)
task_num = len(task)
dimension_num = len(dimension)

# Result buffers: suffix 1 = no-split estimator, suffix 2 = split estimator.
error1_T, error2_T = np.zeros(task_num), np.zeros(task_num)
error1_d, error2_d = np.zeros(dimension_num), np.zeros(dimension_num)

# Ground-truth weights shared by every point of the task-count sweep.
w_star = np.random.randn(args.d) / np.sqrt(args.d)
opt_lam = np.genfromtxt('Lnosplit_opt.csv', delimiter=',', skip_footer=1)

# --- Sweep 1: vary the number of tasks at fixed dimension args.d ---
print('Varying T')
for index, T in enumerate(task):
    no_split_err, split_err = error(
        args.d, args.n, args.n1, T, args.run, args.sigma, opt_lam[-1], w_star
    )
    error1_T[index] = no_split_err
    error2_T[index] = split_err
    print(f'T = {T}, \t error_no_split = {error1_T[index]}, \t error_split = {error2_T[index]}.')

# Row 0 holds the no-split errors, row 1 the split errors.
error_T = np.vstack((error1_T, error2_T))
filename = f'n{args.n}_d{args.d}_n1{args.n1}_varingT.csv'
np.savetxt(filename, error_T, delimiter=',')

# --- Sweep 2: vary the dimension with a fixed 1000 tasks; w_star is redrawn
# for every dimension because its length must match d ---
print('Varying d')
for index, d in enumerate(dimension):
    w_star = np.random.randn(d) / np.sqrt(d)
    no_split_err, split_err = error(
        d, args.n, args.n1, 1000, args.run, args.sigma, opt_lam[index], w_star
    )
    error1_d[index] = no_split_err
    error2_d[index] = split_err
    print(f'd = {d}, \t error_no_split = {error1_d[index]}, \t error_split = {error2_d[index]}.')

error_d = np.vstack((error1_d, error2_d))
filename = f'n{args.n}_n1{args.n1}_varingd.csv'
np.savetxt(filename, error_d, delimiter=',')

# Figure 1: error versus number of tasks on a log scale. error() returns
# T-scaled values, so they are divided by the task count before plotting.
plot1 = plt.figure(1)
for estimate, label in (
    (error1_T, 'No-split est. error'),
    (error2_T, 'Split est. error'),
):
    plt.scatter(task, estimate / task, label=label)
plt.plot(task, 3 / task, label='No-split ref. curve')
plt.plot(task, (args.d / (args.n - args.n1)) / task, label='Split ref. curve')
plt.yscale('log')
plt.legend()

# Figure 2: T-scaled error versus the ratio d / n.
plot2 = plt.figure(2)
ratio = dimension / args.n
for estimate, label in (
    (error1_d, 'No-split est. error'),
    (error2_d, 'Split est. error'),
):
    plt.scatter(ratio, estimate, label=label)
plt.plot(ratio, np.maximum(1, ratio), label='No-split ref. curve')
plt.plot(ratio, 1 + dimension/(args.n - args.n1), label='Split ref. curve')
plt.legend()

plt.show()
